import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from torch_sparse import SparseTensor
from torch_geometric.utils import (to_dense_batch, 
                                   add_remaining_self_loops, 
                                   dropout_edge, 
                                   scatter)
from torch_geometric.utils.num_nodes import maybe_num_nodes


from model.layers import SquarePathIntegral, ClassCommunicator
from model.layer_zoo import GNNLayerZoo
from utils.evaluators import linear_sum_assignment


class FeatAgnosticLinearLearning(nn.Module):
    """Project features of arbitrary width ``d_in`` onto a fixed ``q_dim`` space.

    The projection matrix (``in_subsystem``) comes from one of three sources:

    * ``pretrain and not randin``: it is computed from the input itself via
      ``SquarePathIntegral`` followed by the ``intersub_ff`` MLP;
    * ``randin``: it is drawn uniformly at random on every call;
    * otherwise: the caller must pass it through the ``in_subsystem`` argument.

    Args:
        n_q: Forwarded (times 4) to ``SquarePathIntegral``.
        q_dim: Output width of the projection.
        dropout: Dropout probability used inside ``intersub_ff``.
        pretrain: Enables the learned, input-dependent projection.
        mds_dim: Number of trailing projection columns taken from
            ``messages_mds`` instead of ``SquarePathIntegral``.
        randin: Use a random projection instead of a learned/provided one.
    """

    def __init__(self, n_q, q_dim, dropout, pretrain, mds_dim, randin=False):
        super().__init__()
        self.intersub_ff = nn.Sequential()
        self.mds_dim = mds_dim
        self.q_dim = q_dim
        if pretrain and not randin:
            self.dim_rel = SquarePathIntegral(1, n_q * 4, q_dim - self.mds_dim)
            # Module names are part of the state_dict keys; keep them stable.
            for t in range(4):
                self.intersub_ff.add_module(f'linear{t}', nn.Linear(q_dim, q_dim))
                self.intersub_ff.add_module(f'relu{t}', nn.LeakyReLU())
                self.intersub_ff.add_module(f'dropout{t}', nn.Dropout(dropout))
            self.intersub_ff.add_module('linear', nn.Linear(q_dim, q_dim))
        self.pretrain = pretrain
        self.randin = randin

    def forward(self, messages, messages_mds=None, in_subsystem=None, isgraph=False):
        """Normalize ``messages`` and project them to ``q_dim`` columns.

        Args:
            messages: Tensor whose last two dims are ``(n, d_in)``.
            messages_mds: Extra projection columns concatenated when
                ``mds_dim > 0`` in the pretrain branch (required there).
            in_subsystem: Projection matrix used when neither the pretrain
                nor the random branch is active.
            isgraph: When true, skip the per-column standardization over
                the ``n`` dimension.

        Returns:
            ``layer_norm(messages) @ in_subsystem``.
        """
        # Global min-max rescale. A constant input has zero span, which used
        # to yield 0/0 = NaN everywhere; dividing by 1 instead maps it to 0
        # while leaving every non-degenerate input numerically unchanged.
        lowest = messages.min()
        span = messages.max() - lowest
        span = torch.where(span > 0, span, torch.ones_like(span))
        messages = (messages - lowest) / span

        n, d_in = messages.shape[-2:]
        if not isgraph:
            m_mean = messages.mean(-2, keepdim=True)
            m_var = messages.var(-2, keepdim=True)
            messages = (messages - m_mean) / torch.sqrt(m_var + 1e-5)

        if self.pretrain and not self.randin:
            # View every feature column as a length-n signal and let dim_rel
            # turn it into an initial per-feature embedding.
            messages_rel_in = messages.unsqueeze(-3).transpose(-2, -1) / math.sqrt(n)
            messages_insp = messages_rel_in.sum(-2, keepdim=True) / d_in
            in_subsystem_init = self.dim_rel(messages_insp, messages_rel_in).transpose(-2, -1)

            if self.mds_dim > 0:
                in_subsystem_init = torch.concat((in_subsystem_init, messages_mds), dim=-1)

            # Ratio of two matrix products: (X^T X E) / (X^T X 1).
            in_subsystem_kv = torch.matmul(messages, in_subsystem_init)
            in_subsystem_no = torch.matmul(messages.transpose(-2, -1), in_subsystem_kv)
            in_subsystem_de = torch.matmul(messages.transpose(-2, -1), messages.sum(-1, keepdim=True))
            in_subsystem = in_subsystem_no / (in_subsystem_de + 1e-6)
            in_subsystem = self.intersub_ff(in_subsystem)
        elif self.randin:
            in_subsystem = torch.rand(d_in, self.q_dim, device=messages.device)

        messages = F.layer_norm(messages, (d_in,))
        messages = torch.matmul(messages, in_subsystem)
        return messages



class ClassAgnosticLinearLearning(nn.Module):
    """Turn fixed class encodings into class states conditioned on the graph.

    The class encodings go through a small MLP (``class_neuron_ff``), are
    updated by ``ClassCommunicator`` using their offset from the average
    node state, and are finally standardized across the class dimension.
    """

    def __init__(self, n_q, d_in, q_dim, dropout):
        super().__init__()
        # Build the MLP first so sub-module registration (and parameter
        # initialization) happens in the same order as before.
        feed_forward = nn.Sequential()
        feed_forward.add_module('linear', nn.Linear(d_in, q_dim))
        for suffix in ('0', '1'):
            feed_forward.add_module('relu' + suffix, nn.LeakyReLU())
            feed_forward.add_module('dropout' + suffix, nn.Dropout(dropout))
            feed_forward.add_module('linear' + suffix, nn.Linear(q_dim, q_dim))
        self.class_neuron_ff = feed_forward
        # `collection` and `chidden_norm` are not used by forward() but stay
        # registered so the set of module parameters is unchanged.
        self.collection = SquarePathIntegral(q_dim, n_q)
        self.chidden_norm = nn.LayerNorm(q_dim)
        self.class_agg = ClassCommunicator(q_dim, q_dim, n_q, dropout)

    def forward(self, node_state_avg, class_state):
        """Return standardized class states.

        Args:
            node_state_avg: Average node state; it is broadcast against the
                embedded class states and squeezed on dim 0 before being
                handed to ``class_agg``.
            class_state: Raw class encodings with ``d_in`` features.
        """
        embedded = self.class_neuron_ff(class_state)
        offset = node_state_avg - embedded
        embedded = self.class_agg(embedded, offset.squeeze(0)) + embedded

        # Standardize over the class dimension (-2). The scale is the square
        # root of the standard deviation, exactly as in the original code.
        centred = embedded - embedded.mean(dim=-2, keepdim=True)
        scale = torch.sqrt(embedded.std(dim=-2, keepdim=True)) + 1e-6
        return centred / scale



def sce_loss(x, y, alpha=3):
    """Scaled cosine error: mean of ``(1 - cos(x, y)) ** alpha``.

    Both inputs are L2-normalized along the last dimension, so their
    element-wise product summed over that dimension is the cosine similarity.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cosine = torch.sum(x_unit * y_unit, dim=-1)
    return torch.pow(1 - cosine, alpha).mean()



from functools import partial
class AgnosticGNN(nn.Module):
    """Base class bundling the agnostic learners with shared helper routines.

    It owns the input/output ``FeatAgnosticLinearLearning`` modules, the class
    prototypes (a fixed angular encoding when ``pretrain`` is true, a learned
    parameter otherwise) and utilities for batching (``feature_prep``), output
    matching (``get_output_mapping``) and task heads (``get_output``).
    ``forward`` is a stub meant to be overridden by subclasses.
    """

    def __init__(self, *, d_in=1, 
                          d_ein=0, 
                          nclass=1, 
                          d_model=64, 
                          q_dim=64, 
                          mds_dim=2,
                          n_q=8, 
                          n_pnode=256, 
                          randin=False,
                          T=1, 
                          pretrain=True,
                          task_type="single-class",
                          self_loop=True,
                          pre_encoder=None, 
                          pos_encoder=None, 
                          dropout=0.1):
        super().__init__()
        self.in_aglearner = FeatAgnosticLinearLearning(n_q, q_dim, dropout, pretrain, mds_dim, randin)
        self.out_in_aglearner = FeatAgnosticLinearLearning(n_q, q_dim, dropout, pretrain, 0, randin)
        self.init_dim = 16
        if pretrain:
            # Fixed (non-trainable) class prototypes, refined by a learned module.
            self.register_buffer('class_neuron', self.angle_positional_encoding(nclass, self.init_dim))
            self.out_out_aglearner = ClassAgnosticLinearLearning(n_q, self.init_dim, q_dim, dropout)
            self.feat_neuron = None
        else:
            # Dataset-specific, directly trainable class and feature projections.
            self.class_neuron = nn.Parameter(torch.randn(1, nclass, q_dim))
            self.out_out_aglearner = None
            self.feat_neuron = nn.Parameter(torch.randn(d_in + d_ein, q_dim))
        self.node_state_interface = nn.Sequential(nn.Linear(q_dim, q_dim),
                                                  nn.LeakyReLU(),
                                                  nn.Dropout(dropout))
        self.pre_encoder = pre_encoder 
        self.pos_encoder = pos_encoder 
        self.T = T
        self.n_q = n_q
        self.mds_dim = mds_dim
        self.q_dim = q_dim
        self.n_pnode = n_pnode
        self.d_model = d_model
        self.pretrain = pretrain
        self.task_type = task_type
        self.self_loop = self_loop

    def angle_positional_encoding(self, nclass, final_q_dim):
        """Build ``nclass`` unit vectors from a regular grid of angles.

        Angles are sampled uniformly in the open interval (0, pi) and turned
        into Cartesian points via hyperspherical coordinates. The working
        dimension is ``min(max(2, ceil(log2(nclass))), final_q_dim)``; the
        result is zero-padded on the right up to ``final_q_dim``.

        Returns:
            Tensor of shape ``(1, nclass, final_q_dim)``.
        """
        # Compute the number of points per angle dimension (resolution for each angle dimension)
        q_dim = min(max(2, math.ceil(math.log2(nclass))), final_q_dim)
        num_per_dim = int(math.ceil(nclass**(1 / (q_dim - 1)))) + 2  # Approximate resolution for each angle dimension
        num_per_dim = max(num_per_dim, 3)
        grid = torch.linspace(0, torch.pi, num_per_dim)[1:-1]  # Uniformly distributed angles along (0, π)

        # Generate all combinations of angles (cartesian product for dim-1 dimensions)
        meshes = torch.meshgrid(*([grid] * (q_dim - 1)), indexing='ij')
        theta = torch.stack(meshes, dim=-1)

        # Compute the sine and cosine of all angles
        sin_theta = torch.sin(theta)
        cos_theta = torch.cos(theta)

        # x_i = cos(theta_i) * prod_{j<i} sin(theta_j); the last coordinate is
        # the product of all sines.
        points = torch.zeros((*theta.shape[:-1], q_dim))
        sin_product = torch.ones_like(sin_theta[..., 0])
        for i in range(q_dim - 1):
            points[..., i] = sin_product * cos_theta[..., i]
            sin_product *= sin_theta[..., i]
        points[..., -1] = sin_product

        # Keep only the desired number of points
        points = points.reshape(-1, q_dim)
        if points.size(0) > nclass:  
            points = points[:nclass]
        if q_dim < final_q_dim:
            points = torch.concat((points, points.new_zeros(nclass, final_q_dim - q_dim)), -1)
        return points.unsqueeze(0)

    def __get_sparse_normalized_adj(self, *, edge_index=None, max_num_nodes=None, edge_weight=None, batch=None):
        """Normalize edge weights and re-index edges into a padded batch layout.

        Weights are scaled as ``d_row^-1/2 * w * d_col^-1/2`` where the degree
        is the sum of weights grouped by source node. Node ids are converted to
        per-graph local ids and offset by ``graph_id * max_num_nodes``; edges
        whose local id does not fit into ``max_num_nodes`` are dropped.

        Returns:
            ``(idx, edge_weight)`` with ``idx`` of shape ``(2, num_kept_edges)``.
        """
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
        if edge_weight.dtype == torch.long:
            edge_weight = edge_weight.type(torch.float32)

        # normalize edge weight
        row, col = edge_index[0], edge_index[1]
        deg = scatter(edge_weight, row, 0, 
                      dim_size=maybe_num_nodes(edge_index), 
                      reduce='sum')
        deg_inv_sqrt = deg.pow_(-0.5)
        # Isolated nodes have degree 0 -> inf after the power; zero them out.
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 

        # batch gen
        if batch is None:
            num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
            batch = edge_index.new_zeros(num_nodes)
        batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1

        # transform edge index into batched index with padding
        one = batch.new_ones(batch.size(0))
        num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='sum')
        cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

        idx0 = batch[edge_index[0]]
        idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]
        idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]

        if ((idx1.numel() > 0 and idx1.max() >= max_num_nodes)
            or (idx2.numel() > 0 and idx2.max() >= max_num_nodes)):
            mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)
            idx0 = idx0[mask]
            idx1 = idx1[mask]
            idx2 = idx2[mask]
            edge_weight = edge_weight[mask]
        
        idx = torch.stack((idx1, idx2), 0)
        idx = idx0 * max_num_nodes + idx
        return idx, edge_weight

    def feature_prep(self, data, batch_type="padding"):
        """Prepare node features and edges from a graph ``data`` object.

        Args:
            data: Object exposing ``x``, ``edge_index``, ``edge_attr``,
                ``batch`` and ``keys`` (presumably a PyG ``Data``/``Batch`` --
                confirm against callers); ``rrwp`` is read when a positional
                encoder is configured.
            batch_type: ``"padding"`` (dense features + normalized, re-indexed
                edges), ``"padding-features"`` (dense features, edges
                untouched) or ``"flatten"`` (nothing is padded).

        Returns:
            ``(features, edge_index, edge_weight, size, mask)``; ``mask`` is
            the padding mask from ``to_dense_batch`` or ``None``.

        Raises:
            ValueError: If ``batch_type`` is not one of the supported values.
        """
        features = data.x
        mask = None
        whole_size = features.shape[-2]
        if features.dtype == torch.long:
            features = features.type(torch.float32)
        edge_index, edge_attr = add_remaining_self_loops(data.edge_index, 
                                                         data.edge_attr, 
                                                         num_nodes=whole_size)

        # `keys` is a property on some torch_geometric versions and a method
        # on others; support both.
        data_keys = data.keys() if callable(data.keys) else data.keys
        if 'laplacian_eigenvector_pe' in data_keys:
            features = torch.concat((features, data.laplacian_eigenvector_pe), -1)
        if self.pre_encoder is not None:
            features = self.pre_encoder(features)
        if self.pos_encoder is not None:
            features = self.pos_encoder(features, data.rrwp)
        
        if edge_attr is not None:
            # Sum edge attributes per source node.
            edge_attr = scatter(edge_attr, edge_index[0], 0, 
                                dim_size=maybe_num_nodes(edge_index), 
                                reduce='sum')

        # `data.batch` may be None (single graph); only densify real batches.
        has_batch = data.batch is not None
        multi_graph = has_batch and bool(data.batch.max() > 0)

        if batch_type in ("padding", "padding-features"):
            if multi_graph:
                features, mask = to_dense_batch(features, data.batch)
            if features.ndim == 2:
                features = features.unsqueeze(0)
            b_s, n = features.shape[:2]
            if edge_attr is not None:
                if multi_graph:
                    edge_attr, _ = to_dense_batch(edge_attr, data.batch)
                edge_attr = edge_attr.view(b_s, n, -1)
            if batch_type == "padding":
                edge_index, edge_weight = self.__get_sparse_normalized_adj(edge_index=edge_index, 
                                                                           max_num_nodes=n,
                                                                           batch=data.batch)
            else:
                edge_weight = None
            size = n if not has_batch else (data.batch.max() + 1) * n
        elif batch_type == "flatten":
            edge_weight = None
            size = features.shape[-2]
            if edge_attr is not None:
                edge_attr = edge_attr.view(features.shape[0], -1)
        else:
            # The exception used to be constructed but never raised.
            raise ValueError(f"Unexpected batch type: {batch_type}")

        return features, edge_index, edge_weight, size, mask

    @staticmethod
    def _assignment_to_sparse(cost):
        """Solve the assignment problem on ``cost`` and return it as a sparse 0/1 matrix.

        The matrix is square with side ``cost.shape[-1]`` and lives on the
        same device as ``cost``. ``linear_sum_assignment`` is expected to
        return two NumPy index arrays.
        """
        row_ind, col_ind = linear_sum_assignment(cost.detach().cpu().numpy())
        indices = torch.stack((torch.from_numpy(row_ind),
                               torch.from_numpy(col_ind)), 0).to(cost.device)
        side = cost.shape[-1]
        return torch.sparse_coo_tensor(indices=indices,
                                       values=cost.new_ones(side),
                                       size=(side, side))

    def get_output_mapping(self, matching_class_neuron, class_state, class_mds=None):
        """Match reference class prototypes to the current class states.

        First the hidden dimensions of the projected prototypes are permuted
        to align with the class states (L1 cost between per-dimension means),
        then classes are matched (L1 cost between state and prototype rows).

        Returns:
            Sparse assignment matrix produced by the second matching.
        """
        matching_class_neuron = self.in_aglearner(matching_class_neuron, class_mds, self.feat_neuron, False)
        avg_class_state = class_state.mean(dim=-2, keepdim=True)
        if avg_class_state.ndim > 2:
            avg_class_state = avg_class_state.view(-1, 1, self.q_dim).mean(0)
        avg_class_neuron = matching_class_neuron.mean(dim=-2, keepdim=True)

        # Step 1: align hidden dimensions.
        label_matching = torch.cdist(avg_class_neuron.T, avg_class_state.T, p=1)
        permutation = self._assignment_to_sparse(label_matching)
        matching_class_neuron = torch.matmul(matching_class_neuron, permutation)

        # Step 2: align classes.
        if class_state.ndim > 2:
            class_state = class_state.view(-1, class_state.shape[-2], self.q_dim).mean(0, keepdim=True)
        label_matching = torch.cdist(class_state.squeeze(0), matching_class_neuron, p=1)
        return self._assignment_to_sparse(label_matching)

    def get_output(self, features, task_type=None):
        """Convert raw scores into task-specific outputs.

        Args:
            features: Score tensor; for link tasks it must be 2-D, with the
                first half of the columns used as source embeddings and the
                second half as target embeddings.
            task_type: Overrides ``self.task_type`` when given. Supported:
                any string containing ``"single-class"``, ``"multi-class"``,
                ``"reg"``, or a string containing ``"link"`` together with
                ``"scale-dot"`` or ``"cosine"``.

        Raises:
            ValueError: For an unsupported task type or link variant.
        """
        if task_type is None:
            task_type = self.task_type
        if "single-class" in task_type:
            if features.shape[-1] > 2 or (self.pretrain and features.shape[-1] > 1):
                features = F.layer_norm(features, (features.shape[-1],))
            output = F.log_softmax(features, dim=-1)
        elif task_type == "multi-class":
            # Consecutive column pairs are (negative, positive) logits per label.
            features = features.unflatten(-1, (-1, 2))
            features = F.layer_norm(features, (features.shape[-1],))
            output = F.log_softmax(features, dim=-1)[..., 1]
        elif task_type == "reg":
            output = features
        elif "link" in task_type:
            outdim = features.shape[-1] // 2
            node_in = features[:, :outdim]
            node_out = features[:, outdim:]
            if "scale-dot" in task_type:
                output = torch.matmul(node_in, node_out.T) / outdim
            elif "cosine" in task_type:
                norm_in = torch.norm(node_in, dim=-1)
                norm_out = torch.norm(node_out, dim=-1)
                # Entry (i, j) must be divided by |in_i| * |out_j|: an outer
                # product, not an element-wise one.
                output = torch.matmul(node_in, node_out.T) / (norm_in.unsqueeze(-1) * norm_out.unsqueeze(0))
            else:
                raise ValueError("Unsupported task type " + task_type)
            output = output * 2
        else:
            raise ValueError("Unsupported task type " + task_type)
        return output

    def forward(self):
        """Placeholder; concrete models override this."""
        pass



class GNNZoo(AgnosticGNN):
    """``AgnosticGNN`` whose node states are updated by a stack of ``GNNLayerZoo`` layers.

    ``T`` layers of the chosen ``backbone`` are stacked, except for
    ``"MixHop"`` where a single layer receives ``powers = [0, ..., T-1]``.
    The final scores are dot products between node (or graph) states and the
    class states.
    """

    def __init__(self, *, d_in=1, 
                          d_ein=0, 
                          nclass=1, 
                          d_model=64, 
                          q_dim=64, 
                          mds_dim=2,
                          n_q=8, 
                          backbone="GCN", 
                          n_pnode=256, 
                          T=1, 
                          nh=1,
                          pretrain=True,
                          sage_aggr="mean",
                          task_type="single-class",
                          self_loop=True,
                          randin=False,
                          using_residual=False,
                          pre_encoder=None, 
                          pos_encoder=None, 
                          dropout=0.1):
        super().__init__(d_in=d_in, 
                         d_ein=d_ein,
                         nclass=nclass, 
                         d_model=d_model, 
                         q_dim=q_dim, 
                         mds_dim=mds_dim,
                         n_q=n_q, 
                         n_pnode=n_pnode, 
                         T=T, 
                         randin=randin,
                         pretrain=pretrain,
                         task_type=task_type,
                         pre_encoder=pre_encoder, 
                         pos_encoder=pos_encoder, 
                         self_loop=self_loop,
                         dropout=dropout)
        self.using_residual = using_residual
        # The original assigned 0.8 here and overwrote it with 0.5 at the end
        # of __init__; only the effective value is kept.
        self._drop_edge_rate = 0.5
        self._mask_rate = 0.5
        self._replace_rate = 0.05
        self._mask_token_rate = 1 - self._replace_rate
        self.criterion = partial(sce_loss, alpha=3)
        self.enc_mask_token = nn.Parameter(torch.zeros(1, q_dim))

        print("Using residual: " + str(using_residual))
        self.node_state_updater = nn.ModuleList([])
        powers = []
        if backbone == "MixHop":
            print("MixHop: Forced single layer with powers!")
            powers = list(range(self.T))
            self.T = 1
        for _ in range(self.T):
            self.node_state_updater.append(GNNLayerZoo(d_in=q_dim, 
                                                       d_out=q_dim, 
                                                       backbone=backbone, 
                                                       nh=nh, 
                                                       powers=powers,
                                                       sage_aggr=sage_aggr, 
                                                       nlinear=True, 
                                                       using_residual=False,
                                                       dropout=dropout))
        self.encoder_to_decoder = nn.Linear(q_dim, q_dim, bias=False)

    def encoding_mask_noise(self, x, mask_rate=0.3):
        """Randomly mask rows of ``x``.

        ``int(mask_rate * N)`` rows are selected. When ``_replace_rate > 0``,
        a ``_mask_token_rate`` share of them is zeroed and receives
        ``enc_mask_token``, and a ``_replace_rate`` share is overwritten with
        rows of ``x`` picked at random; otherwise every selected row gets the
        mask token.

        Returns:
            ``(out_x, (mask_nodes, keep_nodes))`` where the index tensors
            partition a random permutation of the row indices.
        """
        num_nodes = x.shape[0]
        perm = torch.randperm(num_nodes, device=x.device)

        # random masking
        num_mask_nodes = int(mask_rate * num_nodes)
        mask_nodes = perm[:num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()
        if self._replace_rate > 0:
            num_noise_nodes = int(self._replace_rate * num_mask_nodes)
            num_token_nodes = int(self._mask_token_rate * num_mask_nodes)
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[perm_mask[:num_token_nodes]]
            # Take the last `num_noise_nodes` entries with an explicit start
            # index: `perm_mask[-0:]` would select *every* masked node and
            # then fail to receive zero replacement rows.
            noise_nodes = mask_nodes[perm_mask[num_mask_nodes - num_noise_nodes:]]
            noise_to_be_chosen = torch.randperm(num_nodes, device=x.device)[:num_noise_nodes]

            out_x[token_nodes] = 0.0
            out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        out_x[token_nodes] += self.enc_mask_token

        return out_x, (mask_nodes, keep_nodes)

    def get_embed(self, data, nclass=None, isgraph=False, sample=False):
        """Encode ``data`` into node states (or per-graph states).

        Args:
            data: Graph object consumed by ``feature_prep``; ``data.mds`` is
                read when ``mds_dim > 0``.
            nclass: Unused here; kept for interface compatibility.
            isgraph: Pool node states per graph (mean) and skip the
                per-column standardization of the input learner.
            sample: Randomly drop edges with probability ``_drop_edge_rate``
                before message passing.

        Returns:
            Node states, or pooled states of shape ``(num_graphs, 1, q_dim)``
            when the batch holds several graphs or ``isgraph`` is true. The
            average state is additionally stored in ``self.node_state_avg``.
        """
        messages, \
            edge_index, \
            edge_weight, \
            size, mask = self.feature_prep(data, "padding-features")
        if sample:
            self.edge_index = edge_index
            edge_index, masked_edges = dropout_edge(edge_index, self._drop_edge_rate)
            edge_weight = edge_weight if edge_weight is None else edge_weight[masked_edges]
            self.used_edge_index = edge_index
        messages = self.in_aglearner(messages, data.mds if self.mds_dim > 0 else None, self.feat_neuron, isgraph)
        node_state = self.node_state_interface(messages).flatten(0, 1)
        if mask is not None:
            # Drop the padding rows introduced by to_dense_batch.
            node_state = node_state[mask.flatten()]
        edge_index = SparseTensor.from_edge_index(
            edge_index, sparse_sizes=(node_state.shape[-2], node_state.shape[-2])
        ).to(edge_index.device)

        for t in range(self.T):
            node_state = self.node_state_updater[t](node_state, 
                                                    edge_index, 
                                                    edge_weight,
                                                    data.batch)
        if self.using_residual:
            node_state = F.layer_norm(node_state, (self.q_dim,))
        if data.batch.max() > 0 or isgraph:
            node_state = scatter(node_state, data.batch, 0, reduce="mean")
            node_state_avg = node_state.unsqueeze(-2)
            node_state = node_state.view(node_state.shape[0], 1, -1)
        else:
            node_state_avg = node_state.mean(-2, keepdim=True).unsqueeze(-2)
        self.node_state_avg = node_state_avg
        return node_state

    def _current_class_state(self, nclass, device):
        """Return the class states, regenerating the prototypes for ``nclass`` if requested.

        The regenerated prototypes are moved to ``device`` (previously they
        were forced onto the default CUDA device, which broke CPU runs).
        Relies on ``self.node_state_avg`` set by ``get_embed``.
        """
        if nclass is not None and self.pretrain:
            self.class_neuron = self.angle_positional_encoding(nclass, self.init_dim).to(device)
        if self.pretrain:
            return self.out_out_aglearner(self.node_state_avg, self.class_neuron)
        return self.class_neuron

    def forward(self, data, nclass=None, multi_label_flag=False, isgraph=False):
        """Score every node/graph against every class and apply the task head.

        Edge dropping is enabled whenever the model is in ``pretrain`` mode.
        """
        node_state = self.get_embed(data, nclass, isgraph, sample=self.pretrain)
        class_state = self._current_class_state(nclass, node_state.device)
        messages = torch.matmul(node_state, class_state.squeeze(0).transpose(-2, -1))
        messages = messages.view(-1, messages.shape[-1])
        task_type = "multi-class" if multi_label_flag else None
        return self.get_output(messages, task_type)

    def prediction(self, data, output_mapping=None, nclass=None, multi_label_flag=False, isgraph=False, matching_class_neuron=None):
        """Inference-time scoring without edge dropping.

        If ``matching_class_neuron`` is given, an output mapping is computed
        from it (reading ``data.class_mds[0]`` when ``mds_dim > 0``) and
        replaces ``output_mapping``. Without any mapping the raw scores are
        returned; otherwise the mapped scores go through ``get_output``.
        """
        task_type = "multi-class" if multi_label_flag else None
        node_state = self.get_embed(data, nclass, isgraph)
        class_state = self._current_class_state(nclass, node_state.device)
        messages = torch.matmul(node_state.unsqueeze(-3), class_state.squeeze(0).transpose(-2, -1))
        messages = messages.view(-1, messages.shape[-1])
        if matching_class_neuron is not None:
            output_mapping = self.get_output_mapping(matching_class_neuron, class_state, data.class_mds[0] if self.mds_dim > 0 else None)
        if output_mapping is None:
            return messages
        messages = torch.matmul(messages, output_mapping)
        return self.get_output(messages, task_type)